{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt identifiers used to filter generated CoTs in the cells below.\n",
    "none = [None]  # not referenced by any other cell in this notebook\n",
    "# CoT trigger keys, passed as cot_trigger= to select_generated_cots\n",
    "cots = [\"kojima-01\", \"kojima-03\", \"zhou-01\"]\n",
    "# Instruction keys, passed as instruction= to select_generated_cots\n",
    "instructions = [\"zhou-01-ins\", \"qa-10\", \"qa-12\", \"qa-13\", \"qa-16\", \"qa-17\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the collection that the merge cells below extend; without this,\n",
    "# `ts_1k` is undefined when the notebook is run top to bottom.\n",
    "ts_1k = Collection.from_json(\"ts_1k.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the collection returned by load_thoughtsource_100; the later cells\n",
    "# filter either this object or a copy of it.\n",
    "ts_100 = Collection.load_thoughtsource_100()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict ts_100 to generated CoTs whose author is \"thoughtsource\".\n",
    "# The return value is not used, so this presumably filters in place - confirm.\n",
    "ts_100.select_generated_cots(author=\"thoughtsource\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## None: generations without a CoT trigger or instruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline subset: copy ts_100, then select thoughtsource generations with\n",
    "# cot_trigger=None and instruction=None.\n",
    "ts_100_filtered = ts_100.copy()\n",
    "ts_100_filtered.select_generated_cots(\n",
    "    author=\"thoughtsource\",\n",
    "    cot_trigger=None,\n",
    "    instruction=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kon/work/ThoughtSource/libs/cot/cot/stats.py:406: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
      "  df.loc[dataset, (instruction + \"_\" + cot_trigger, model)] = v\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_a9b6c_row0_col1, #T_a9b6c_row1_col3, #T_a9b6c_row2_col3, #T_a9b6c_row3_col3, #T_a9b6c_row4_col3, #T_a9b6c_row5_col3, #T_a9b6c_row6_col3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_a9b6c\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_a9b6c_level0_col0\" class=\"col_heading level0 col0\" colspan=\"6\">None_None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_a9b6c_level1_col0\" class=\"col_heading level1 col0\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_a9b6c_level1_col1\" class=\"col_heading level1 col1\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_a9b6c_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_a9b6c_level1_col3\" class=\"col_heading level1 col3\" >gpt-4</th>\n",
       "      <th id=\"T_a9b6c_level1_col4\" class=\"col_heading level1 col4\" >text-davinci-002</th>\n",
       "      <th id=\"T_a9b6c_level1_col5\" class=\"col_heading level1 col5\" >text-davinci-003</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_a9b6c_row0_col0\" class=\"data row0 col0\" >0.51</td>\n",
       "      <td id=\"T_a9b6c_row0_col1\" class=\"data row0 col1\" >0.87</td>\n",
       "      <td id=\"T_a9b6c_row0_col2\" class=\"data row0 col2\" >0.72</td>\n",
       "      <td id=\"T_a9b6c_row0_col3\" class=\"data row0 col3\" >0.75</td>\n",
       "      <td id=\"T_a9b6c_row0_col4\" class=\"data row0 col4\" >0.76</td>\n",
       "      <td id=\"T_a9b6c_row0_col5\" class=\"data row0 col5\" >0.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_a9b6c_row1_col0\" class=\"data row1 col0\" >0.23</td>\n",
       "      <td id=\"T_a9b6c_row1_col1\" class=\"data row1 col1\" >0.32</td>\n",
       "      <td id=\"T_a9b6c_row1_col2\" class=\"data row1 col2\" >0.58</td>\n",
       "      <td id=\"T_a9b6c_row1_col3\" class=\"data row1 col3\" >0.73</td>\n",
       "      <td id=\"T_a9b6c_row1_col4\" class=\"data row1 col4\" >0.41</td>\n",
       "      <td id=\"T_a9b6c_row1_col5\" class=\"data row1 col5\" >0.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_a9b6c_row2_col0\" class=\"data row2 col0\" >0.25</td>\n",
       "      <td id=\"T_a9b6c_row2_col1\" class=\"data row2 col1\" >0.34</td>\n",
       "      <td id=\"T_a9b6c_row2_col2\" class=\"data row2 col2\" >0.58</td>\n",
       "      <td id=\"T_a9b6c_row2_col3\" class=\"data row2 col3\" >0.69</td>\n",
       "      <td id=\"T_a9b6c_row2_col4\" class=\"data row2 col4\" >0.34</td>\n",
       "      <td id=\"T_a9b6c_row2_col5\" class=\"data row2 col5\" >0.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_a9b6c_row3_col0\" class=\"data row3 col0\" >0.59</td>\n",
       "      <td id=\"T_a9b6c_row3_col1\" class=\"data row3 col1\" >0.82</td>\n",
       "      <td id=\"T_a9b6c_row3_col2\" class=\"data row3 col2\" >0.77</td>\n",
       "      <td id=\"T_a9b6c_row3_col3\" class=\"data row3 col3\" >0.92</td>\n",
       "      <td id=\"T_a9b6c_row3_col4\" class=\"data row3 col4\" >0.67</td>\n",
       "      <td id=\"T_a9b6c_row3_col5\" class=\"data row3 col5\" >0.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_a9b6c_row4_col0\" class=\"data row4 col0\" >0.52</td>\n",
       "      <td id=\"T_a9b6c_row4_col1\" class=\"data row4 col1\" >0.61</td>\n",
       "      <td id=\"T_a9b6c_row4_col2\" class=\"data row4 col2\" >0.57</td>\n",
       "      <td id=\"T_a9b6c_row4_col3\" class=\"data row4 col3\" >0.71</td>\n",
       "      <td id=\"T_a9b6c_row4_col4\" class=\"data row4 col4\" >0.36</td>\n",
       "      <td id=\"T_a9b6c_row4_col5\" class=\"data row4 col5\" >0.53</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_a9b6c_row5_col0\" class=\"data row5 col0\" >0.63</td>\n",
       "      <td id=\"T_a9b6c_row5_col1\" class=\"data row5 col1\" >0.85</td>\n",
       "      <td id=\"T_a9b6c_row5_col2\" class=\"data row5 col2\" >0.96</td>\n",
       "      <td id=\"T_a9b6c_row5_col3\" class=\"data row5 col3\" >0.99</td>\n",
       "      <td id=\"T_a9b6c_row5_col4\" class=\"data row5 col4\" >0.88</td>\n",
       "      <td id=\"T_a9b6c_row5_col5\" class=\"data row5 col5\" >0.91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_a9b6c_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_a9b6c_row6_col0\" class=\"data row6 col0\" >0.46</td>\n",
       "      <td id=\"T_a9b6c_row6_col1\" class=\"data row6 col1\" >0.64</td>\n",
       "      <td id=\"T_a9b6c_row6_col2\" class=\"data row6 col2\" >0.70</td>\n",
       "      <td id=\"T_a9b6c_row6_col3\" class=\"data row6 col3\" >0.80</td>\n",
       "      <td id=\"T_a9b6c_row6_col4\" class=\"data row6 col4\" >0.57</td>\n",
       "      <td id=\"T_a9b6c_row6_col5\" class=\"data row6 col5\" >0.62</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fedcf8fe860>"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval = ts_100_filtered.evaluate()\n",
    "table = evaluation_as_table(eval)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the currently filtered ts_100 subset into ts_1k, then write the\n",
    "# merged collection to \"ts_1k\".\n",
    "ts_1k = ts_1k.merge(ts_100_filtered)\n",
    "ts_1k.dump(\"ts_1k\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## cots: generations with a CoT trigger and no instruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoT-trigger subset: copy ts_100, then select thoughtsource generations with\n",
    "# cot_trigger taken from `cots` and instruction=None.\n",
    "ts_100_filtered = ts_100.copy()\n",
    "ts_100_filtered.select_generated_cots(\n",
    "    author=\"thoughtsource\",\n",
    "    cot_trigger=cots,\n",
    "    instruction=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kon/work/ThoughtSource/libs/cot/cot/stats.py:406: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
      "  df.loc[dataset, (instruction + \"_\" + cot_trigger, model)] = v\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_d2ee3_row0_col6, #T_d2ee3_row1_col8, #T_d2ee3_row2_col8, #T_d2ee3_row3_col8, #T_d2ee3_row4_col8, #T_d2ee3_row5_col8, #T_d2ee3_row6_col8 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_d2ee3\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_d2ee3_level0_col0\" class=\"col_heading level0 col0\" colspan=\"5\">None_kojima-01</th>\n",
       "      <th id=\"T_d2ee3_level0_col5\" class=\"col_heading level0 col5\" >None_kojima-03</th>\n",
       "      <th id=\"T_d2ee3_level0_col6\" class=\"col_heading level0 col6\" colspan=\"3\">None_zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_d2ee3_level1_col0\" class=\"col_heading level1 col0\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_d2ee3_level1_col1\" class=\"col_heading level1 col1\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_d2ee3_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d2ee3_level1_col3\" class=\"col_heading level1 col3\" >text-davinci-002</th>\n",
       "      <th id=\"T_d2ee3_level1_col4\" class=\"col_heading level1 col4\" >text-davinci-003</th>\n",
       "      <th id=\"T_d2ee3_level1_col5\" class=\"col_heading level1 col5\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d2ee3_level1_col6\" class=\"col_heading level1 col6\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_d2ee3_level1_col7\" class=\"col_heading level1 col7\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_d2ee3_level1_col8\" class=\"col_heading level1 col8\" >gpt-4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_d2ee3_row0_col0\" class=\"data row0 col0\" >0.53</td>\n",
       "      <td id=\"T_d2ee3_row0_col1\" class=\"data row0 col1\" >0.81</td>\n",
       "      <td id=\"T_d2ee3_row0_col2\" class=\"data row0 col2\" >0.67</td>\n",
       "      <td id=\"T_d2ee3_row0_col3\" class=\"data row0 col3\" >0.62</td>\n",
       "      <td id=\"T_d2ee3_row0_col4\" class=\"data row0 col4\" >0.65</td>\n",
       "      <td id=\"T_d2ee3_row0_col5\" class=\"data row0 col5\" >0.63</td>\n",
       "      <td id=\"T_d2ee3_row0_col6\" class=\"data row0 col6\" >0.83</td>\n",
       "      <td id=\"T_d2ee3_row0_col7\" class=\"data row0 col7\" >0.66</td>\n",
       "      <td id=\"T_d2ee3_row0_col8\" class=\"data row0 col8\" >0.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_d2ee3_row1_col0\" class=\"data row1 col0\" >0.31</td>\n",
       "      <td id=\"T_d2ee3_row1_col1\" class=\"data row1 col1\" >0.34</td>\n",
       "      <td id=\"T_d2ee3_row1_col2\" class=\"data row1 col2\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row1_col3\" class=\"data row1 col3\" >0.34</td>\n",
       "      <td id=\"T_d2ee3_row1_col4\" class=\"data row1 col4\" >0.43</td>\n",
       "      <td id=\"T_d2ee3_row1_col5\" class=\"data row1 col5\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row1_col6\" class=\"data row1 col6\" >0.27</td>\n",
       "      <td id=\"T_d2ee3_row1_col7\" class=\"data row1 col7\" >0.65</td>\n",
       "      <td id=\"T_d2ee3_row1_col8\" class=\"data row1 col8\" >0.76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_d2ee3_row2_col0\" class=\"data row2 col0\" >0.22</td>\n",
       "      <td id=\"T_d2ee3_row2_col1\" class=\"data row2 col1\" >0.35</td>\n",
       "      <td id=\"T_d2ee3_row2_col2\" class=\"data row2 col2\" >0.47</td>\n",
       "      <td id=\"T_d2ee3_row2_col3\" class=\"data row2 col3\" >0.34</td>\n",
       "      <td id=\"T_d2ee3_row2_col4\" class=\"data row2 col4\" >0.36</td>\n",
       "      <td id=\"T_d2ee3_row2_col5\" class=\"data row2 col5\" >0.50</td>\n",
       "      <td id=\"T_d2ee3_row2_col6\" class=\"data row2 col6\" >0.31</td>\n",
       "      <td id=\"T_d2ee3_row2_col7\" class=\"data row2 col7\" >0.48</td>\n",
       "      <td id=\"T_d2ee3_row2_col8\" class=\"data row2 col8\" >0.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_d2ee3_row3_col0\" class=\"data row3 col0\" >0.38</td>\n",
       "      <td id=\"T_d2ee3_row3_col1\" class=\"data row3 col1\" >0.79</td>\n",
       "      <td id=\"T_d2ee3_row3_col2\" class=\"data row3 col2\" >0.77</td>\n",
       "      <td id=\"T_d2ee3_row3_col3\" class=\"data row3 col3\" >0.57</td>\n",
       "      <td id=\"T_d2ee3_row3_col4\" class=\"data row3 col4\" >0.67</td>\n",
       "      <td id=\"T_d2ee3_row3_col5\" class=\"data row3 col5\" >0.73</td>\n",
       "      <td id=\"T_d2ee3_row3_col6\" class=\"data row3 col6\" >0.81</td>\n",
       "      <td id=\"T_d2ee3_row3_col7\" class=\"data row3 col7\" >0.81</td>\n",
       "      <td id=\"T_d2ee3_row3_col8\" class=\"data row3 col8\" >0.95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_d2ee3_row4_col0\" class=\"data row4 col0\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row4_col1\" class=\"data row4 col1\" >0.69</td>\n",
       "      <td id=\"T_d2ee3_row4_col2\" class=\"data row4 col2\" >0.56</td>\n",
       "      <td id=\"T_d2ee3_row4_col3\" class=\"data row4 col3\" >0.38</td>\n",
       "      <td id=\"T_d2ee3_row4_col4\" class=\"data row4 col4\" >0.54</td>\n",
       "      <td id=\"T_d2ee3_row4_col5\" class=\"data row4 col5\" >0.52</td>\n",
       "      <td id=\"T_d2ee3_row4_col6\" class=\"data row4 col6\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row4_col7\" class=\"data row4 col7\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row4_col8\" class=\"data row4 col8\" >0.80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_d2ee3_row5_col0\" class=\"data row5 col0\" >0.58</td>\n",
       "      <td id=\"T_d2ee3_row5_col1\" class=\"data row5 col1\" >0.80</td>\n",
       "      <td id=\"T_d2ee3_row5_col2\" class=\"data row5 col2\" >0.93</td>\n",
       "      <td id=\"T_d2ee3_row5_col3\" class=\"data row5 col3\" >0.78</td>\n",
       "      <td id=\"T_d2ee3_row5_col4\" class=\"data row5 col4\" >0.89</td>\n",
       "      <td id=\"T_d2ee3_row5_col5\" class=\"data row5 col5\" >0.95</td>\n",
       "      <td id=\"T_d2ee3_row5_col6\" class=\"data row5 col6\" >0.83</td>\n",
       "      <td id=\"T_d2ee3_row5_col7\" class=\"data row5 col7\" >0.92</td>\n",
       "      <td id=\"T_d2ee3_row5_col8\" class=\"data row5 col8\" >0.99</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d2ee3_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_d2ee3_row6_col0\" class=\"data row6 col0\" >0.44</td>\n",
       "      <td id=\"T_d2ee3_row6_col1\" class=\"data row6 col1\" >0.63</td>\n",
       "      <td id=\"T_d2ee3_row6_col2\" class=\"data row6 col2\" >0.66</td>\n",
       "      <td id=\"T_d2ee3_row6_col3\" class=\"data row6 col3\" >0.50</td>\n",
       "      <td id=\"T_d2ee3_row6_col4\" class=\"data row6 col4\" >0.59</td>\n",
       "      <td id=\"T_d2ee3_row6_col5\" class=\"data row6 col5\" >0.65</td>\n",
       "      <td id=\"T_d2ee3_row6_col6\" class=\"data row6 col6\" >0.61</td>\n",
       "      <td id=\"T_d2ee3_row6_col7\" class=\"data row6 col7\" >0.68</td>\n",
       "      <td id=\"T_d2ee3_row6_col8\" class=\"data row6 col8\" >0.82</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fed1c0dcf10>"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval = ts_100_filtered.evaluate()\n",
    "table = evaluation_as_table(eval)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2e33f9205d644bfb242077a3b915e49",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f24d1fcfb0674c93b4d4f0eba7256c2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10477d4f15404b1cb7f9e3cea89e71cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39cb3a6ad7e54b1a95075201c17536a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6728f292560345d1a65d05635b187851",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b78ae995a8a463c98d90dca10014b0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Merge the currently filtered ts_100 subset into ts_1k, then write the\n",
    "# merged collection to \"ts_1k\".\n",
    "# NOTE(review): ts_1k is reassigned from itself, so re-running this cell\n",
    "# merges again - confirm that merge tolerates duplicates.\n",
    "ts_1k = ts_1k.merge(ts_100_filtered)\n",
    "ts_1k.dump(\"ts_1k\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## instructions: generations with an instruction and no CoT trigger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction subset: copy ts_100, then select thoughtsource generations with\n",
    "# instruction taken from `instructions` and cot_trigger=None.\n",
    "ts_100_filtered = ts_100.copy()\n",
    "ts_100_filtered.select_generated_cots(\n",
    "    author=\"thoughtsource\",\n",
    "    cot_trigger=None,\n",
    "    instruction=instructions,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_c28ba_row0_col5, #T_c28ba_row1_col3, #T_c28ba_row2_col1, #T_c28ba_row3_col1, #T_c28ba_row4_col2, #T_c28ba_row5_col2, #T_c28ba_row6_col0 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_c28ba\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_c28ba_level0_col0\" class=\"col_heading level0 col0\" >qa-10_None</th>\n",
       "      <th id=\"T_c28ba_level0_col1\" class=\"col_heading level0 col1\" >qa-12_None</th>\n",
       "      <th id=\"T_c28ba_level0_col2\" class=\"col_heading level0 col2\" >qa-13_None</th>\n",
       "      <th id=\"T_c28ba_level0_col3\" class=\"col_heading level0 col3\" >qa-16_None</th>\n",
       "      <th id=\"T_c28ba_level0_col4\" class=\"col_heading level0 col4\" >qa-17_None</th>\n",
       "      <th id=\"T_c28ba_level0_col5\" class=\"col_heading level0 col5\" >zhou-01-ins_None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_c28ba_level1_col0\" class=\"col_heading level1 col0\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_c28ba_level1_col1\" class=\"col_heading level1 col1\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_c28ba_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_c28ba_level1_col3\" class=\"col_heading level1 col3\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_c28ba_level1_col4\" class=\"col_heading level1 col4\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_c28ba_level1_col5\" class=\"col_heading level1 col5\" >gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_c28ba_row0_col0\" class=\"data row0 col0\" >0.68</td>\n",
       "      <td id=\"T_c28ba_row0_col1\" class=\"data row0 col1\" >0.63</td>\n",
       "      <td id=\"T_c28ba_row0_col2\" class=\"data row0 col2\" >0.61</td>\n",
       "      <td id=\"T_c28ba_row0_col3\" class=\"data row0 col3\" >0.58</td>\n",
       "      <td id=\"T_c28ba_row0_col4\" class=\"data row0 col4\" >0.66</td>\n",
       "      <td id=\"T_c28ba_row0_col5\" class=\"data row0 col5\" >0.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_c28ba_row1_col0\" class=\"data row1 col0\" >0.56</td>\n",
       "      <td id=\"T_c28ba_row1_col1\" class=\"data row1 col1\" >0.59</td>\n",
       "      <td id=\"T_c28ba_row1_col2\" class=\"data row1 col2\" >0.49</td>\n",
       "      <td id=\"T_c28ba_row1_col3\" class=\"data row1 col3\" >0.60</td>\n",
       "      <td id=\"T_c28ba_row1_col4\" class=\"data row1 col4\" >0.56</td>\n",
       "      <td id=\"T_c28ba_row1_col5\" class=\"data row1 col5\" >0.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_c28ba_row2_col0\" class=\"data row2 col0\" >0.49</td>\n",
       "      <td id=\"T_c28ba_row2_col1\" class=\"data row2 col1\" >0.53</td>\n",
       "      <td id=\"T_c28ba_row2_col2\" class=\"data row2 col2\" >0.41</td>\n",
       "      <td id=\"T_c28ba_row2_col3\" class=\"data row2 col3\" >0.48</td>\n",
       "      <td id=\"T_c28ba_row2_col4\" class=\"data row2 col4\" >0.53</td>\n",
       "      <td id=\"T_c28ba_row2_col5\" class=\"data row2 col5\" >0.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_c28ba_row3_col0\" class=\"data row3 col0\" >0.73</td>\n",
       "      <td id=\"T_c28ba_row3_col1\" class=\"data row3 col1\" >0.80</td>\n",
       "      <td id=\"T_c28ba_row3_col2\" class=\"data row3 col2\" >0.72</td>\n",
       "      <td id=\"T_c28ba_row3_col3\" class=\"data row3 col3\" >0.69</td>\n",
       "      <td id=\"T_c28ba_row3_col4\" class=\"data row3 col4\" >0.69</td>\n",
       "      <td id=\"T_c28ba_row3_col5\" class=\"data row3 col5\" >0.76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_c28ba_row4_col0\" class=\"data row4 col0\" >0.56</td>\n",
       "      <td id=\"T_c28ba_row4_col1\" class=\"data row4 col1\" >0.50</td>\n",
       "      <td id=\"T_c28ba_row4_col2\" class=\"data row4 col2\" >0.64</td>\n",
       "      <td id=\"T_c28ba_row4_col3\" class=\"data row4 col3\" >0.63</td>\n",
       "      <td id=\"T_c28ba_row4_col4\" class=\"data row4 col4\" >0.58</td>\n",
       "      <td id=\"T_c28ba_row4_col5\" class=\"data row4 col5\" >0.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_c28ba_row5_col0\" class=\"data row5 col0\" >0.95</td>\n",
       "      <td id=\"T_c28ba_row5_col1\" class=\"data row5 col1\" >0.92</td>\n",
       "      <td id=\"T_c28ba_row5_col2\" class=\"data row5 col2\" >0.96</td>\n",
       "      <td id=\"T_c28ba_row5_col3\" class=\"data row5 col3\" >0.91</td>\n",
       "      <td id=\"T_c28ba_row5_col4\" class=\"data row5 col4\" >0.92</td>\n",
       "      <td id=\"T_c28ba_row5_col5\" class=\"data row5 col5\" >0.96</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c28ba_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_c28ba_row6_col0\" class=\"data row6 col0\" >0.66</td>\n",
       "      <td id=\"T_c28ba_row6_col1\" class=\"data row6 col1\" >0.66</td>\n",
       "      <td id=\"T_c28ba_row6_col2\" class=\"data row6 col2\" >0.64</td>\n",
       "      <td id=\"T_c28ba_row6_col3\" class=\"data row6 col3\" >0.65</td>\n",
       "      <td id=\"T_c28ba_row6_col4\" class=\"data row6 col4\" >0.66</td>\n",
       "      <td id=\"T_c28ba_row6_col5\" class=\"data row6 col5\" >0.66</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fedcfd444c0>"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval = ts_100_filtered.evaluate()\n",
    "table = evaluation_as_table(eval)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fccef7b62dc24cd1b6458db2a91eadf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9aa9468c283c4eb788b1ccd122cea4ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4457e5ef7009435daba35a0efcc20a8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4fcb340ab212431e827a3267b62fd3b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2890bc614bed4efb807f93a707d1f797",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb54a31390ef41fe8cb8d0498d558908",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Merge the currently filtered ts_100 subset into ts_1k, then write the\n",
    "# merged collection to \"ts_1k\".\n",
    "# NOTE(review): ts_1k is reassigned from itself, so re-running this cell\n",
    "# merges again - confirm that merge tolerates duplicates.\n",
    "ts_1k = ts_1k.merge(ts_100_filtered)\n",
    "ts_1k.dump(\"ts_1k\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "commonsense_qa {1, 22}\n",
      "med_qa {1, 22}\n",
      "medmc_qa {1, 22}\n",
      "open_book_qa {1, 22}\n",
      "strategy_qa {1, 22}\n",
      "worldtree {1, 22}\n"
     ]
    }
   ],
   "source": [
    "# Sanity check after the merges: report generated-CoT counts per dataset.\n",
    "ts_1k.number_generated_cots()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
